#!/usr/bin/env python
import sys
import requests

SOURCE1_NAME = "mysql-01"
SOURCE2_NAME = "mysql-02"
WORKER1_NAME = "worker1"
WORKER2_NAME = "worker2"


API_ENDPOINT = "http://127.0.0.1:8261/api/v1/sources"
API_ENDPOINT_NOT_LEADER = "http://127.0.0.1:8361/api/v1/sources"


def create_source_failed():
    """POST to the sources endpoint without a body and expect HTTP 400.

    The response JSON is printed only after the status check has passed.
    """
    response = requests.post(API_ENDPOINT)
    status = response.status_code
    assert status == 400
    print("create_source_failed resp=", response.json())


def create_source1_success():
    """Create the source named ``SOURCE1_NAME`` (port 3306) and expect HTTP 201.

    The response JSON is printed before the status assertion, so the body
    is visible in the output even when the assertion fails.
    """
    source_config = dict(
        case_sensitive=False,
        enable_gtid=False,
        host="127.0.0.1",
        password="123456",
        port=3306,
        source_name=SOURCE1_NAME,
        user="root",
    )
    response = requests.post(API_ENDPOINT, json=source_config)
    print("create_source1_success resp=", response.json())
    assert response.status_code == 201


def create_source2_success():
    """Create the source named ``SOURCE2_NAME`` (port 3307) and expect HTTP 201.

    The response JSON is printed before the status assertion, so the body
    is visible in the output even when the assertion fails.
    """
    req = {
        "case_sensitive": False,
        "enable_gtid": False,
        "host": "127.0.0.1",
        "password": "123456",
        "port": 3307,
        "source_name": SOURCE2_NAME,
        "user": "root",
    }
    resp = requests.post(url=API_ENDPOINT, json=req)
    # Label the output with this function's own name; it previously reused
    # the "create_source1_success" label, which made logs misleading.
    print("create_source2_success resp=", resp.json())
    assert resp.status_code == 201


def list_source_success(source_count):
    """List sources and check that the reported total matches.

    Args:
        source_count: Expected value of the ``total`` field; converted with
            ``int()``, so a numeric string from the command line works.
    """
    response = requests.get(API_ENDPOINT)
    assert response.status_code == 200
    listing = response.json()
    print("list_source_by_openapi_success resp=", listing)
    reported_total = listing["total"]
    assert reported_total == int(source_count)


def list_source_with_status_success(source_count, status_count):
    """List sources with ``with_status=true`` and check the status list sizes.

    Args:
        source_count: Expected ``total``; also the number of leading entries
            of ``data`` that are inspected. Converted with ``int()``.
        status_count: Expected length of each inspected entry's
            ``status_list``. Converted with ``int()``.
    """
    query = "?with_status=true"
    response = requests.get(API_ENDPOINT + query)
    assert response.status_code == 200
    listing = response.json()
    print("list_source_with_status_success resp=", listing)
    reported_total = listing["total"]
    assert reported_total == int(source_count)
    for idx in range(int(source_count)):
        statuses = listing["data"][idx]["status_list"]
        assert len(statuses) == int(status_count)


def list_source_with_redirect(source_count):
    """List sources through ``API_ENDPOINT_NOT_LEADER`` and check the total.

    NOTE(review): judging by the names, this endpoint is presumably a
    non-leader node that forwards the request; confirm against the
    test setup.

    Args:
        source_count: Expected value of the ``total`` field; converted with
            ``int()``.
    """
    response = requests.get(API_ENDPOINT_NOT_LEADER)
    assert response.status_code == 200
    listing = response.json()
    print("list_source_by_openapi_redirect resp=", listing)
    reported_total = listing["total"]
    assert reported_total == int(source_count)


def delete_source_success(source_name):
    """Delete the named source and expect HTTP 204.

    Args:
        source_name: Name appended to the sources endpoint path.
    """
    target = "".join([API_ENDPOINT, "/", source_name])
    response = requests.delete(target)
    assert response.status_code == 204
    print("delete_source_success")


def delete_source_with_force_success(source_name):
    """Delete the named source with ``force=true`` and expect HTTP 204.

    Args:
        source_name: Name appended to the sources endpoint path.
    """
    target = "".join([API_ENDPOINT, "/", source_name, "?force=true"])
    response = requests.delete(target)
    assert response.status_code == 204
    print("delete_source_with_force_success")


def delete_source_failed(source_name):
    """Try to delete the named source and expect HTTP 400.

    The response JSON is printed before the status assertion.

    Args:
        source_name: Name appended to the sources endpoint path.
    """
    target = "".join([API_ENDPOINT, "/", source_name])
    response = requests.delete(target)
    print("delete_source_failed msg=", response.json())
    assert response.status_code == 400


def start_relay_failed(source_name, worker_name):
    """Request start-relay for one worker and expect HTTP 400.

    Unlike the success variant, the request body carries no ``purge``
    settings. The response JSON is printed before the status assertion.

    Args:
        source_name: Source whose ``start-relay`` endpoint is called.
        worker_name: Sole entry of ``worker_name_list`` in the request body.
    """
    endpoint = "".join([API_ENDPOINT, "/", source_name, "/start-relay"])
    body = {"worker_name_list": [worker_name]}
    response = requests.post(endpoint, json=body)
    print("start_relay_failed resp=", response.json())
    assert response.status_code == 400


def start_relay_success(source_name, worker_name):
    """Request start-relay for one worker and expect HTTP 200.

    Args:
        source_name: Source whose ``start-relay`` endpoint is called.
        worker_name: Sole entry of ``worker_name_list`` in the request body.
    """
    endpoint = "".join([API_ENDPOINT, "/", source_name, "/start-relay"])
    purge_settings = {"interval": 3600, "expires": 0, "remain_space": 15}
    body = {
        "worker_name_list": [worker_name],
        "purge": purge_settings,
    }
    response = requests.post(endpoint, json=body)
    assert response.status_code == 200


def pause_relay_success(source_name):
    """POST to the source's ``pause-relay`` endpoint and expect HTTP 200.

    Args:
        source_name: Source whose ``pause-relay`` endpoint is called.
    """
    endpoint = "".join([API_ENDPOINT, "/", source_name, "/pause-relay"])
    response = requests.post(endpoint)
    assert response.status_code == 200


def resume_relay_success(source_name):
    """POST to the source's ``resume-relay`` endpoint and expect HTTP 200.

    Args:
        source_name: Source whose ``resume-relay`` endpoint is called.
    """
    endpoint = "".join([API_ENDPOINT, "/", source_name, "/resume-relay"])
    response = requests.post(endpoint)
    assert response.status_code == 200


def start_relay_success_with_two_worker(source_name, worker1_name, worker2_name):
    """Request start-relay for two workers at once and expect HTTP 200.

    Args:
        source_name: Source whose ``start-relay`` endpoint is called.
        worker1_name: First entry of ``worker_name_list``.
        worker2_name: Second entry of ``worker_name_list``.
    """
    endpoint = "".join([API_ENDPOINT, "/", source_name, "/start-relay"])
    workers = [worker1_name, worker2_name]
    purge_settings = {"interval": 3600, "expires": 0, "remain_space": 15}
    body = {
        "worker_name_list": workers,
        "purge": purge_settings,
    }
    response = requests.post(endpoint, json=body)
    assert response.status_code == 200


def get_source_status_failed(source_name):
    """GET the source's ``status`` endpoint and expect HTTP 400.

    The response JSON is printed before the status assertion.

    Args:
        source_name: Source whose ``status`` endpoint is queried.
    """
    endpoint = "".join([API_ENDPOINT, "/", source_name, "/status"])
    response = requests.get(endpoint)
    print("get_source_status_failed resp=", response.json())
    assert response.status_code == 400


def get_source_status_success(source_name, status_count=1):
    """GET the source's ``status`` endpoint and check the reported total.

    Args:
        source_name: Source whose ``status`` endpoint is queried.
        status_count: Expected value of the ``total`` field; converted with
            ``int()``. Defaults to 1.
    """
    endpoint = "".join([API_ENDPOINT, "/", source_name, "/status"])
    response = requests.get(endpoint)
    assert response.status_code == 200
    body = response.json()
    print("get_source_status_success resp=", body)
    expected_total = int(status_count)
    assert expected_total == body["total"]


def get_source_status_success_with_relay(source_name, source_idx=0):
    """Check that a status entry carries a non-null ``relay_status``.

    Args:
        source_name: Source whose ``status`` endpoint is queried.
        source_idx: Index into the response's ``data`` list; converted with
            ``int()``. Defaults to 0.
    """
    endpoint = "".join([API_ENDPOINT, "/", source_name, "/status"])
    response = requests.get(endpoint)
    assert response.status_code == 200
    body = response.json()
    print("get_source_status_success_with_relay resp=", body)
    entry = body["data"][int(source_idx)]
    # Direct indexing: a missing "relay_status" key raises KeyError.
    assert entry["relay_status"] is not None


def get_source_status_success_no_relay(source_name, source_idx=0):
    """Check that a status entry has no ``relay_status`` (absent or null).

    Args:
        source_name: Source whose ``status`` endpoint is queried.
        source_idx: Index into the response's ``data`` list; converted with
            ``int()``. Defaults to 0.
    """
    endpoint = "".join([API_ENDPOINT, "/", source_name, "/status"])
    response = requests.get(endpoint)
    assert response.status_code == 200
    body = response.json()
    print("get_source_status_success_no_relay resp=", body)
    entry = body["data"][int(source_idx)]
    # .get() treats a missing key the same as an explicit null value.
    relay_status = entry.get("relay_status")
    assert relay_status is None


def stop_relay_failed(source_name, worker_name):
    """Request stop-relay for one worker and expect HTTP 400.

    The response JSON is printed before the status assertion.

    Args:
        source_name: Source whose ``stop-relay`` endpoint is called.
        worker_name: Sole entry of ``worker_name_list`` in the request body.
    """
    endpoint = "".join([API_ENDPOINT, "/", source_name, "/stop-relay"])
    body = {"worker_name_list": [worker_name]}
    response = requests.post(endpoint, json=body)
    print("stop_relay_failed resp=", response.json())
    assert response.status_code == 400


def stop_relay_success(source_name, worker_name):
    """Request stop-relay for one worker and expect HTTP 200.

    Args:
        source_name: Source whose ``stop-relay`` endpoint is called.
        worker_name: Sole entry of ``worker_name_list`` in the request body.
    """
    endpoint = "".join([API_ENDPOINT, "/", source_name, "/stop-relay"])
    body = {"worker_name_list": [worker_name]}
    response = requests.post(endpoint, json=body)
    assert response.status_code == 200


def transfer_source_success(source_name, worker_name):
    """POST to the source's ``transfer`` endpoint and expect HTTP 200.

    Args:
        source_name: Source whose ``transfer`` endpoint is called.
        worker_name: Value sent as ``worker_name`` in the request body.
    """
    endpoint = "".join([API_ENDPOINT, "/", source_name, "/transfer"])
    body = {"worker_name": worker_name}
    response = requests.post(endpoint, json=body)
    assert response.status_code == 200


def get_source_schemas_and_tables_success(source_name, schema_name="", table_name=""):
    """Check that a schema and one of its tables are listed for a source.

    First fetches ``<source>/schemas`` and asserts ``schema_name`` is in the
    returned collection, then fetches ``<source>/schemas/<schema_name>`` and
    asserts ``table_name`` is in that collection.

    Args:
        source_name: Source whose schema endpoints are queried.
        schema_name: Schema expected in the schema listing; also used to
            build the table-listing URL.
        table_name: Table expected in the table listing of ``schema_name``.
    """
    schema_url = API_ENDPOINT + "/" + source_name + "/schemas"
    schema_resp = requests.get(url=schema_url)
    assert schema_resp.status_code == 200
    schema_list = schema_resp.json()
    print("get_source_schemas_and_tables_success schema_resp=", schema_list)
    assert schema_name in schema_list

    # Query the tables of the schema that was just verified, rather than a
    # hard-coded "openapi" schema, so the two checks refer to the same schema.
    table_url = schema_url + "/" + schema_name
    table_resp = requests.get(url=table_url)
    # Check the status like every other "success" helper does; otherwise an
    # error payload could be misread as a table list.
    assert table_resp.status_code == 200
    table_list = table_resp.json()
    print("get_source_schemas_and_tables_success table_resp=", table_list)
    assert table_name in table_list


if __name__ == "__main__":
    # Command-line dispatcher: `<script> <function_name> [args...]`.
    # Any extra arguments are passed to the selected function as positional
    # strings, which is why the helpers convert counts with int().
    FUNC_MAP = {
        "create_source_failed": create_source_failed,
        "create_source1_success": create_source1_success,
        "create_source2_success": create_source2_success,
        "list_source_success": list_source_success,
        "list_source_with_redirect": list_source_with_redirect,
        "list_source_with_status_success": list_source_with_status_success,
        "delete_source_failed": delete_source_failed,
        "delete_source_success": delete_source_success,
        "delete_source_with_force_success": delete_source_with_force_success,
        "start_relay_failed": start_relay_failed,
        "pause_relay_success": pause_relay_success,
        "resume_relay_success": resume_relay_success,
        "start_relay_success": start_relay_success,
        "start_relay_success_with_two_worker": start_relay_success_with_two_worker,
        "stop_relay_failed": stop_relay_failed,
        "stop_relay_success": stop_relay_success,
        "get_source_status_failed": get_source_status_failed,
        "get_source_status_success": get_source_status_success,
        "get_source_status_success_with_relay": get_source_status_success_with_relay,
        "get_source_status_success_no_relay": get_source_status_success_no_relay,
        "transfer_source_success": transfer_source_success,
        "get_source_schemas_and_tables_success": get_source_schemas_and_tables_success,
    }

    # Validate the arguments before indexing sys.argv; previously a missing
    # function name raised IndexError before the length check could run.
    if len(sys.argv) < 2:
        sys.exit("usage: {} <function_name> [args...]".format(sys.argv[0]))

    func_name = sys.argv[1]
    if func_name not in FUNC_MAP:
        sys.exit(
            "unknown function {!r}; expected one of: {}".format(
                func_name, ", ".join(sorted(FUNC_MAP))
            )
        )

    # Unpacking an empty argument list is the same as calling func().
    func = FUNC_MAP[func_name]
    func(*sys.argv[2:])
